"""
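Phase 2: Gravity model estimation with bilateral demographic distance.

Models:
  2a. Baseline gravity (no demographics)
  2b. Add bilateral demographic distance (ΔZ_1, ΔZ_2, ΔZ_3)
  2c. KAOPEN interactions (ΔZ × KAOPEN_j)
  2d. Separate regressions for portfolio equity, portfolio debt, and FDI
  2e. Price channel controls (interest rate differential; REER is not yet included)
  2f. Two-stage (Carvalho) bilateral model using a demographically fitted
      rate differential, plus an R² mediation decomposition (rate vs. direct channel)

Output: gravity_bilateral/output/tables/gravity_results.csv
        gravity_bilateral/output/tables/mediation_decomposition.csv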
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys

# Project root. It is also placed at the front of sys.path so that the
# project-local ``src`` package (PanelGLS) can be imported below.
BASE_DIR = Path("/mnt/c/demographics_capital_flows/gravity_bilateral")
sys.path.insert(0, str(BASE_DIR))
from src.model import PanelGLS  # noqa: E402  (requires the sys.path entry above)

# Paths
PROCESSED_DIR = BASE_DIR / "data" / "processed"
OUTPUT_DIR = BASE_DIR / "output" / "tables"


def load_panel():
    """Load the bilateral panel and add year-dummy columns.

    Reads ``bilateral_panel.csv`` from ``PROCESSED_DIR`` and adds an integer
    (0/1) column ``yr_<year>`` for every distinct non-missing year except the
    earliest one, which is left out as the reference year. Existing columns
    with the same names are overwritten.

    Returns:
        tuple: ``(df, years)`` where ``df`` is the panel with the dummy
        columns added and ``years`` is the sorted list of distinct
        non-missing values of ``df['year']``.
    """
    df = pd.read_csv(PROCESSED_DIR / "bilateral_panel.csv")
    print(f"Loaded bilateral panel: {len(df):,} obs, {df['pair_id'].nunique():,} pairs")

    years = sorted(df['year'].dropna().unique())
    # Build all dummies in one assign() call (rather than inserting columns one
    # at a time); the earliest year is omitted as the reference category.
    df = df.assign(**{
        f'yr_{int(y)}': (df['year'] == y).astype(int)
        for y in years[1:]
    })

    return df, years


def _significance_stars(p_value):
    """Return '***', '**', '*' for p < 0.01, < 0.05, < 0.1; otherwise ''."""
    if p_value < 0.01:
        return '***'
    if p_value < 0.05:
        return '**'
    if p_value < 0.1:
        return '*'
    return ''


def estimate_model(df, dep_var, regressors, model_name, year_dummies=True):
    """Estimate one gravity specification with PanelGLS and tabulate it.

    Rows with a missing ``dep_var``, regressor, ``pair_id`` or ``year`` are
    dropped (listwise deletion). When ``year_dummies`` is true, the
    ``yr_<year>`` columns already present in ``df`` are appended to the
    design matrix for every year in the *estimation sample* except the
    sample's earliest year, which serves as the reference.

    Args:
        df: Panel DataFrame with ``pair_id``, ``year``, ``dep_var`` and every
            column listed in ``regressors``.
        dep_var: Name of the dependent-variable column.
        regressors: Regressor column names. Only these coefficients (not the
            year dummies) are printed and returned.
        model_name: Label used in printed output and in the ``model`` column.
        year_dummies: Whether to add year-dummy columns to the regressors.

    Returns:
        tuple: ``(gls, results)``, where ``gls`` is the fitted ``PanelGLS``
        and ``results`` is a long-format DataFrame with one row per
        regressor plus ``_R_squared``, ``_N_obs``, ``_N_pairs`` and ``_rho``
        summary rows. Returns ``(None, None)`` when fewer than 100 complete
        observations remain.
    """
    regressors = list(regressors)
    est = df.dropna(subset=[dep_var] + regressors + ['pair_id', 'year'])

    yr_cols = []
    if year_dummies:
        # Use only years that actually occur in the estimation sample and
        # drop the sample's earliest year as reference. Taking dummies for
        # every panel year would add all-zero columns for years removed by
        # listwise deletion (and keep a full dummy set when the panel's
        # reference year is missing), leaving the design matrix rank-deficient.
        sample_years = sorted(est['year'].unique())
        yr_cols = [f'yr_{int(y)}' for y in sample_years[1:]
                   if f'yr_{int(y)}' in df.columns]
        if yr_cols:
            est = est.dropna(subset=yr_cols)
    est = est.copy()
    all_vars = regressors + yr_cols

    if len(est) < 100:
        print(f"  {model_name}: insufficient obs ({len(est)})")
        return None, None

    print(f"\n{'=' * 70}")
    print(f"  {model_name}")
    print(f"  Dep var: {dep_var}")
    print(f"  N = {len(est):,}, Pairs = {est['pair_id'].nunique():,}")
    print(f"  Years: {est['year'].min():.0f}-{est['year'].max():.0f}")
    print(f"{'=' * 70}")

    y = est[dep_var].values
    X = est[all_vars].values

    gls = PanelGLS()
    gls.fit(y, X, est['pair_id'].values, est['year'].values.astype(int))

    # Regressors occupy the first len(regressors) columns of X, so the first
    # coefficients belong to them; year dummies follow and are not reported.
    print(f"  R² = {gls.r_squared:.4f}, ρ = {gls.rho:.3f}")
    print(f"  {'Variable':<30} {'Coef':>10} {'SE':>10} {'p-val':>8}")
    print(f"  {'-' * 60}")
    for i, v in enumerate(regressors):
        sig = _significance_stars(gls.pvalues[i])
        print(f"  {v:<30} {gls.beta[i]:>10.4f} {gls.se[i]:>10.4f} {gls.pvalues[i]:>8.4f} {sig}")

    rows = [
        {
            'model': model_name,
            'variable': v,
            'coefficient': gls.beta[i],
            'std_error': gls.se[i],
            't_stat': gls.tvalues[i],
            'p_value': gls.pvalues[i],
        }
        for i, v in enumerate(regressors)
    ]
    # Model-level statistics are stored in the 'coefficient' column.
    # NOTE(review): _N_pairs comes from PanelGLS.n_countries, presumably the
    # number of distinct panel ids passed to fit() — confirm in src.model.
    summary_stats = [
        ('_R_squared', gls.r_squared),
        ('_N_obs', gls.n_obs),
        ('_N_pairs', gls.n_countries),
        ('_rho', gls.rho),
    ]
    for label, value in summary_stats:
        rows.append({
            'model': model_name,
            'variable': label,
            'coefficient': value,
            'std_error': np.nan,
            't_stat': np.nan,
            'p_value': np.nan,
        })

    return gls, pd.DataFrame(rows)


# First-stage (S1) coefficients of the multilateral Z-only model for
# real_bond_10y_diff, estimated on 23 OECD countries, 689 obs
# (see rate_channel_tests.csv). A country's demographically fitted yield is
# sum_k S1_COEFS['Z_k'] * Z_k.
S1_COEFS = {
    'Z_1': 16.319743982966457,
    'Z_2': -2.0745621076630774,
    'Z_3': 0.07178776182359387,
}


def _available_vars(df, candidates, min_non_null):
    """Return candidates present in ``df`` with more than ``min_non_null`` non-null values."""
    return [v for v in candidates
            if v in df.columns and df[v].notna().sum() > min_non_null]


def _print_price_comparison(res_demo, res_price, demo_vars):
    """Print each ΔZ coefficient without vs. with price controls (model 2b vs 2e)."""
    print(f"\n{'=' * 70}")
    print("  KEY TEST: Do ΔZ coefficients survive price controls?")
    print(f"{'=' * 70}")
    print(f"  {'Variable':<20} {'Without prices':>15} {'With prices':>15} {'Change':>10}")
    no_price = res_demo[res_demo['variable'].isin(demo_vars)]
    with_price = res_price[res_price['variable'].isin(demo_vars)]
    for _, row_np in no_price.iterrows():
        v = row_np['variable']
        row_wp = with_price[with_price['variable'] == v]
        if len(row_wp) > 0:
            coef_np = row_np['coefficient']
            coef_wp = row_wp.iloc[0]['coefficient']
            # Percent change relative to |coefficient without prices|.
            pct_change = (coef_wp - coef_np) / abs(coef_np) * 100 if coef_np != 0 else np.nan
            print(f"  {v:<20} {coef_np:>15.4f} {coef_wp:>15.4f} {pct_change:>9.1f}%")


def _add_fitted_rate_diff(df):
    """Add fitted_yield_i/_j and fitted_rate_diff_ij columns to ``df`` in place.

    Returns True when the columns were built, False when any of the
    Z_{1,2,3}_{i,j} inputs is missing (``df`` is then left untouched).
    """
    z_cols = [f'Z_{k}_{side}' for side in ('i', 'j') for k in (1, 2, 3)]
    if not all(c in df.columns for c in z_cols):
        return False
    for side in ('i', 'j'):
        df[f'fitted_yield_{side}'] = (
            S1_COEFS['Z_1'] * df[f'Z_1_{side}']
            + S1_COEFS['Z_2'] * df[f'Z_2_{side}']
            + S1_COEFS['Z_3'] * df[f'Z_3_{side}']
        )
    df['fitted_rate_diff_ij'] = df['fitted_yield_i'] - df['fitted_yield_j']
    return True


def _mediation_decomposition(gls_base, gls_demo, gls_rate):
    """Split the R² gain from ΔZ into a fitted-rate share and a residual share.

    total ΔZ gain = R²(2b) - R²(2a); rate channel = R²(2f) - R²(2a);
    direct/other = total - rate channel. Prints the decomposition and writes
    ``mediation_decomposition.csv`` to ``OUTPUT_DIR``.
    """
    r2_base = gls_base.r_squared
    r2_demo = gls_demo.r_squared
    r2_rate = gls_rate.r_squared

    total_demo_contribution = r2_demo - r2_base
    rate_contribution = r2_rate - r2_base
    direct_contribution = total_demo_contribution - rate_contribution

    print(f"\n{'=' * 70}")
    print("  MEDIATION DECOMPOSITION: Rate Channel vs Direct")
    print(f"{'=' * 70}")
    print(f"  Baseline R² (gravity only):        {r2_base:.4f}")
    print(f"  R² with full ΔZ:                   {r2_demo:.4f}  (+{total_demo_contribution:.4f})")
    print(f"  R² with fitted Δr̂ only:            {r2_rate:.4f}  (+{rate_contribution:.4f})")

    # Each model uses listwise deletion on its own regressors, so the three
    # R² values may come from different samples and not be strictly comparable.
    n_obs = (gls_base.n_obs, gls_demo.n_obs, gls_rate.n_obs)
    if len(set(n_obs)) > 1:
        print(f"  WARNING: models use different samples "
              f"(N = {n_obs[0]:,.0f} / {n_obs[1]:,.0f} / {n_obs[2]:,.0f}); "
              "R² differences mix sample and specification effects.")
    print()

    rate_share = direct_share = np.nan
    if total_demo_contribution > 0:
        rate_share = rate_contribution / total_demo_contribution * 100
        direct_share = direct_contribution / total_demo_contribution * 100
        print(f"  Rate-mediated share:               {rate_share:.1f}%")
        print(f"  Direct/other channel share:        {direct_share:.1f}%")
        print()
        print("  (cf. clearing channels 5.1: rates=9%, direct=89%)")

    decomp = pd.DataFrame([
        {'metric': 'R2_baseline', 'value': r2_base},
        {'metric': 'R2_full_demographics', 'value': r2_demo},
        {'metric': 'R2_fitted_rate_only', 'value': r2_rate},
        {'metric': 'total_demo_R2_improvement', 'value': total_demo_contribution},
        {'metric': 'rate_channel_R2_improvement', 'value': rate_contribution},
        {'metric': 'direct_channel_R2_improvement', 'value': direct_contribution},
        {'metric': 'rate_channel_share_pct', 'value': rate_share},
        {'metric': 'direct_channel_share_pct', 'value': direct_share},
    ])
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    decomp.to_csv(OUTPUT_DIR / "mediation_decomposition.csv", index=False)
    print(f"  Saved: {OUTPUT_DIR / 'mediation_decomposition.csv'}")


def _first_stat(mdf, name, fmt):
    """Format the first 'coefficient' value of row ``name`` in ``mdf``, or 'N/A'."""
    vals = mdf.loc[mdf['variable'] == name, 'coefficient'].values
    return format(vals[0], fmt) if len(vals) > 0 else "N/A"


def _print_model_summary(results_df):
    """Print one line per model with its R², N and ρ."""
    print(f"\n{'=' * 70}")
    print("MODEL COMPARISON SUMMARY")
    print(f"{'=' * 70}")
    for m in results_df['model'].unique():
        mdf = results_df[results_df['model'] == m]
        r2_str = _first_stat(mdf, '_R_squared', '.4f')
        n_str = _first_stat(mdf, '_N_obs', ',.0f')
        rho_str = _first_stat(mdf, '_rho', '.3f')
        print(f"  {m:<50} R²={r2_str}  N={n_str}  ρ={rho_str}")


def main():
    """Run the Phase 2 gravity specifications (2a-2f) and save the results.

    Each specification is skipped when its required columns are absent or too
    sparse. Writes ``gravity_results.csv`` to ``OUTPUT_DIR`` (created if
    needed) and, when models 2a, 2b and 2f all estimate,
    ``mediation_decomposition.csv``.

    Returns:
        pandas.DataFrame | None: Long-format results for every estimated
        model, or None when nothing could be estimated.
    """
    print("=" * 70)
    print("PHASE 2: GRAVITY MODEL ESTIMATION")
    print("=" * 70)

    df, _ = load_panel()
    all_results = []

    # Dependent variable: prefer total portfolio, then equity, then FDI.
    dep_candidates = _available_vars(
        df, ['log_portfolio_total', 'log_portfolio_equity', 'log_fdi_outward'], 500)
    if not dep_candidates:
        print("ERROR: No usable dependent variable found!")
        print("Available flow columns:")
        for c in df.columns:
            if 'log_' in c or 'portfolio' in c or 'fdi' in c:
                print(f"  {c}: {df[c].notna().sum():,} non-null")
        return None
    dep_var = dep_candidates[0]

    print(f"\nPrimary dependent variable: {dep_var}")
    print(f"  Non-null: {df[dep_var].notna().sum():,}")

    # ===================================================================
    # 2a. Baseline gravity (no demographics)
    # ===================================================================
    gravity_vars = _available_vars(
        df, ['log_dist', 'contiguity', 'common_lang_official', 'colonial_ties', 'log_gdp_product'], 500)
    if not gravity_vars:
        print("ERROR: No gravity variables available!")
        return None

    print(f"\nGravity variables: {gravity_vars}")

    gls_base, res_base = estimate_model(df, dep_var, gravity_vars, "2a: Baseline Gravity")
    if res_base is not None:
        all_results.append(res_base)

    # ===================================================================
    # 2b. Add bilateral demographic distance
    # ===================================================================
    # Initialised up front because 2e and the 2f mediation step consult them
    # even when no ΔZ columns exist; otherwise that case raises NameError.
    gls_demo, res_demo = None, None
    demo_vars = [v for v in ['dZ_1', 'dZ_2', 'dZ_3'] if v in df.columns]
    if demo_vars:
        gls_demo, res_demo = estimate_model(
            df, dep_var, gravity_vars + demo_vars,
            "2b: Gravity + Demographics"
        )
        if res_demo is not None:
            all_results.append(res_demo)

    # ===================================================================
    # 2c. KAOPEN interactions
    # ===================================================================
    kaopen_vars = [v for v in ['dZ_1_x_kaopen_j', 'dZ_2_x_kaopen_j', 'dZ_3_x_kaopen_j']
                   if v in df.columns]
    kaopen_level = ['kaopen_j'] if 'kaopen_j' in df.columns else []

    if demo_vars and kaopen_vars:
        _, res_kaopen = estimate_model(
            df, dep_var, gravity_vars + demo_vars + kaopen_level + kaopen_vars,
            "2c: Gravity + Demographics + KAOPEN interactions"
        )
        if res_kaopen is not None:
            all_results.append(res_kaopen)

    # ===================================================================
    # 2d. Separate portfolio equity, debt, and FDI
    # ===================================================================
    for flow_type, flow_dep in [('Portfolio Equity', 'log_portfolio_equity'),
                                ('Portfolio Debt', 'log_portfolio_debt'),
                                ('FDI Outward', 'log_fdi_outward')]:
        if flow_dep not in df.columns or df[flow_dep].notna().sum() < 200:
            print(f"\n  Skipping {flow_type}: insufficient data")
            continue

        _, res_flow = estimate_model(
            df, flow_dep, gravity_vars + demo_vars,
            f"2d: {flow_type}"
        )
        if res_flow is not None:
            all_results.append(res_flow)

    # ===================================================================
    # 2e. Price channel controls
    # ===================================================================
    price_vars = _available_vars(df, ['rate_diff_ij'], 200)

    if demo_vars and price_vars:
        gls_price, res_price = estimate_model(
            df, dep_var, gravity_vars + demo_vars + price_vars,
            "2e: Gravity + Demographics + Price Controls"
        )
        if res_price is not None:
            all_results.append(res_price)

        # Key test: compare ΔZ coefficients with and without price controls
        if gls_demo is not None and gls_price is not None:
            _print_price_comparison(res_demo, res_price, demo_vars)

    # ===================================================================
    # 2f. Two-stage Carvalho bilateral: fitted rate differential
    # ===================================================================
    if _add_fitted_rate_diff(df):
        print(f"\n  Fitted rate differential: mean={df['fitted_rate_diff_ij'].mean():.3f}, "
              f"std={df['fitted_rate_diff_ij'].std():.3f}")

        # Model 2f-i: rate-mediated demographics only
        gls_rate, res_rate = estimate_model(
            df, dep_var, gravity_vars + ['fitted_rate_diff_ij'],
            "2f: Gravity + Fitted Rate Differential (Carvalho)"
        )
        if res_rate is not None:
            all_results.append(res_rate)

        # Model 2f-ii: do observed rates add to the fitted demographic component?
        if 'rate_diff_ij' in df.columns:
            _, res_both = estimate_model(
                df, dep_var, gravity_vars + demo_vars + ['fitted_rate_diff_ij', 'rate_diff_ij'],
                "2f-ii: Full Model + Fitted & Actual Rates"
            )
            if res_both is not None:
                all_results.append(res_both)

        if gls_base is not None and gls_demo is not None and gls_rate is not None:
            _mediation_decomposition(gls_base, gls_demo, gls_rate)

    # ===================================================================
    # Combine and save results
    # ===================================================================
    if not all_results:
        return None

    results_df = pd.concat(all_results, ignore_index=True)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    outfile = OUTPUT_DIR / "gravity_results.csv"
    results_df.to_csv(outfile, index=False)
    print(f"\n  Saved all results: {outfile}")

    _print_model_summary(results_df)
    return results_df


if __name__ == "__main__":
    results = main()
